from __future__ import annotations

import itertools
import json
import random
import re
import string
from collections import defaultdict
from typing import TYPE_CHECKING

import requests
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QMessageBox
from sortedcontainers import SortedDict

from angrmanagement.config import Conf
from angrmanagement.plugins.base_plugin import BasePlugin

if TYPE_CHECKING:
    from angrmanagement.ui.workspace import Workspace


class VaRec(BasePlugin):
    """
    Plugin that predicts variable names for the function shown in the pseudocode view by
    querying a VaRec backend service (private for now until it is released to the public).

    Workflow: every variable the user has not renamed is temporarily renamed to a
    placeholder of the form ``@@<original name>@@<random tag>@@``. The regenerated
    pseudocode is POSTed to ``Conf.varec_endpoint``, and the backend's top-k predictions
    are used to choose new names. The code expects one prediction entry per placeholder
    occurrence, in text order. On failure the placeholders are reverted to the original
    names.
    """

    MENU_BUTTONS = [
        "&Infer variable names",
    ]
    INFER_VARIABLE_NAMES = 0

    # Matches a placeholder name; group 1 is the original variable name, group 2 the tag.
    # The quantifiers are non-greedy so that adjacent placeholders in the pseudocode text
    # (e.g. ``a0[v1]`` -> ``@@a0@@x@@[@@v1@@y@@``) are matched separately instead of
    # being merged into a single bogus match.
    _PLACEHOLDER_RE = re.compile(r"@@(\S+?)@@(\S+?)@@")

    # Predicted names that are never applied to a variable.
    VARNAME_BLACKLIST = frozenset({"UNK", "null", "true", "false", "return", "do", "while"})

    # (connect, read) timeout in seconds for the backend request, so a stalled or
    # unreachable backend cannot block the UI forever.
    REQUEST_TIMEOUT: tuple[float, float] = (10.0, 300.0)

    def __init__(self, workspace: Workspace) -> None:
        super().__init__(workspace)

        # NOTE(review): the three attributes below are not used anywhere in this class;
        # kept in case other code relies on them.
        self.transitions: set[tuple[int, int]] = set()
        self.covered_blocks = SortedDict()

        self.sink_color = Qt.GlobalColor.yellow

    def handle_click_menu(self, idx: int) -> None:
        """Dispatch a click on the menu entry at index *idx* of ``MENU_BUTTONS``.

        Does nothing for an out-of-range index or when no project is loaded.
        """
        if idx < 0 or idx >= len(self.MENU_BUTTONS):
            return

        if self.workspace.main_instance.project.am_none:
            return

        mapping = {
            VaRec.INFER_VARIABLE_NAMES: self.infer_variable_names,
        }

        handler = mapping.get(idx)
        if handler is not None:
            handler()

    @staticmethod
    def _restore_stage(view) -> None:
        """Revert every placeholder name ``@@<original>@@<tag>@@`` back to ``<original>``.

        Only variables of ``view.function`` whose name currently is a placeholder are
        touched. Afterwards the pseudocode text is regenerated and the view is notified.
        """
        for v in view.codegen._variable_kb.variables[view.function.addr]._unified_variables:
            m = VaRec._PLACEHOLDER_RE.match(v.name)
            if m is not None:
                v.name = m.group(1)
        # refresh the view
        view.codegen.regenerate_text()
        view.codegen.am_event()

    @staticmethod
    def randstr(n: int = 8):
        """Return a random string of *n* lowercase ASCII letters (used as a placeholder tag)."""
        return "".join(random.choice(string.ascii_lowercase) for _ in range(n))

    def _show_error(self, message: str) -> None:
        """Show a modal error dialog about a failed variable-name prediction."""
        QMessageBox.critical(
            self.workspace._main_window,
            "Error in variable name prediction",
            message,
            QMessageBox.StandardButton.Ok,
        )

    def _fail(self, view, message: str) -> None:
        """Revert placeholder names in *view*, then report *message* to the user."""
        # Restore first so the user sees the original pseudocode behind the dialog.
        self._restore_stage(view)
        self._show_error(message)

    def _request_predictions(self, view) -> list | None:
        """Send the placeholder-annotated pseudocode of *view* to the VaRec backend.

        Returns ``result["code"][0]["predictions"][0]`` from the backend's JSON reply. The
        caller expects this to be one entry per placeholder occurrence. On a network
        error, undecodable JSON, or an unexpected reply shape, this method restores the
        placeholder names, shows an error dialog, and returns ``None``.
        """
        proxies = {"http": Conf.http_proxy, "https": Conf.https_proxy} if Conf.http_proxy or Conf.https_proxy else None

        payload = {
            "code": [
                {
                    "raw_codes": [
                        view.codegen.text,
                    ]
                }
            ]
        }

        try:
            r = requests.post(
                f"{Conf.varec_endpoint}",
                data=json.dumps(payload),
                proxies=proxies,
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as ex:
            self._fail(view, f"Failed to contact the prediction backend. Error: {ex}")
            return None

        try:
            result = json.loads(r.text)
        except json.JSONDecodeError:
            self._fail(view, "Failed to predict names for all variables involved.")
            return None

        # handle failure cases
        code = result.get("code") if isinstance(result, dict) else None
        if not code:
            self._fail(view, "Unexpected output returned from the backend. 'code' is not found or empty.")
            return None

        first = code[0] if isinstance(code, list) else None
        predictions = first.get("predictions") if isinstance(first, dict) else None
        if not predictions or not isinstance(predictions, list):
            self._fail(view, "Unexpected output returned from the backend. 'predictions' is not found or empty.")
            return None

        # The backend reports errors as a single-string prediction list.
        if len(predictions) == 1 and isinstance(predictions[0], str):
            self._fail(view, f"Prediction failed. Error: {predictions[0]}")
            return None

        return predictions[0]

    @staticmethod
    def _collect_predictions(text: str, per_placeholder) -> defaultdict:
        """Map each original variable name to its non-blacklisted top-k prediction items.

        The i-th placeholder occurrence in *text* is paired with ``per_placeholder[i]``.
        That entry must contain a ``"top-k"`` list of items that each carry a
        ``"pred_name"`` key. Raises KeyError/IndexError/TypeError on malformed data.
        """
        varname_to_predicted = defaultdict(list)
        for idx, m in enumerate(VaRec._PLACEHOLDER_RE.finditer(text)):
            topk = per_placeholder[idx]["top-k"]
            # remove variable names that we don't like
            filtered_topk = [item for item in topk if item["pred_name"] not in VaRec.VARNAME_BLACKLIST]
            if filtered_topk:
                varname_to_predicted[m.group(1)].extend(filtered_topk)
        return varname_to_predicted

    @staticmethod
    def _assign_names(variables, varname_to_predicted) -> None:
        """Rename every placeholder variable in *variables* using its predictions.

        For each placeholder variable, predictions are tried in order of decreasing
        ``"confidence"``. The highest-ranked name not already taken is used. If all are
        taken, the top prediction gets a numeric ``_N`` suffix. With no predictions at
        all, the original name is restored. Names of non-placeholder variables (e.g.
        user-renamed ones) are reserved up front so no duplicate is produced.
        ``candidate_names`` is set to the full set of predicted names. Raises
        KeyError/TypeError on malformed prediction items. In that case no variable has
        been modified yet.
        """
        used_names = set()
        pending = []
        for v in variables:
            m = VaRec._PLACEHOLDER_RE.match(v.name)
            if m is None:
                used_names.add(v.name)
            else:
                pending.append((v, m.group(1)))

        ctrs = defaultdict(itertools.count)
        assignments = []
        for v, var_name in pending:
            predicted = sorted(varname_to_predicted.get(var_name, []), key=lambda x: x["confidence"], reverse=True)
            candidates = {pred["pred_name"] for pred in predicted}
            for pred in predicted:
                if pred["pred_name"] not in used_names:
                    new_name = pred["pred_name"]
                    break
            else:
                if predicted:
                    base = predicted[0]["pred_name"]
                    new_name = f"{base}_{next(ctrs[base])}"
                    # skip suffixes that collide with a name already in use
                    while new_name in used_names:
                        new_name = f"{base}_{next(ctrs[base])}"
                else:
                    new_name = var_name  # restore the original name
            used_names.add(new_name)
            assignments.append((v, new_name, candidates))

        # Apply only once every name has been chosen, so malformed backend data can
        # never leave the variables half-renamed.
        for v, new_name, candidates in assignments:
            v.candidate_names = candidates
            v.name = new_name

    def infer_variable_names(self) -> None:
        """Predict and apply names for the not-yet-renamed variables of the pseudocode function.

        Shows an error dialog, without renaming anything, when the pseudocode view is
        empty, has no variables KB, or the backend call fails or returns malformed output.
        """
        view = self.workspace._get_or_create_pseudocode_view()
        if view.codegen.am_none:
            self._show_error("Cannot predict variable names. No pseudocode exists in the pseudocode view.")
            return
        if view.codegen._variable_kb is None:
            self._show_error(
                "Cannot predict variable names. The pseudocode view does not have associated variables KB."
            )
            return

        variables = view.codegen._variable_kb.variables[view.function.addr]._unified_variables

        # Tag each variable the user has not renamed, so that its occurrences can be
        # located in the pseudocode text and matched back to the variable afterwards.
        for v in variables:
            if not v.renamed:
                v.name = f"@@{v.name}@@{VaRec.randstr()}@@"
        view.codegen.regenerate_text()

        predictions = self._request_predictions(view)
        if predictions is None:
            # names were already restored and the error reported
            return

        try:
            varname_to_predicted = self._collect_predictions(view.codegen.text, predictions)
            self._assign_names(variables, varname_to_predicted)
        except (KeyError, IndexError, TypeError):
            self._fail(view, "Unexpected output returned from the backend. 'predictions' is malformed.")
            return

        view.codegen.am_event()
